# Functions

# Plot color palette
#
# Draws a single stacked bar with one segment per entry of `input_cols`,
# filled with that color, to preview a palette. Segment order follows the
# order of `input_cols` (via fct_inorder).
#
# input_cols: character vector of colors
#
# Returns a ggplot object.
plot_color_palette <- function(input_cols) {
  palette_df <- tibble(color = fct_inorder(input_cols))

  ggplot(palette_df, aes(x = "color", fill = color)) +
    geom_bar() +
    scale_fill_manual(values = input_cols) +
    theme_void()
}

# Merge a list of Seurat objects, scale the RNA and ADT assays, and find
# variable RNA features
#
# sobj_list:    list of Seurat objects (at least one). Its names are used as
#               cell-barcode prefixes.
# sample_order: level order applied to the orig.ident metadata column (passed
#               to fct_relevel)
#
# Returns the merged Seurat object.
merge_sobj <- function(sobj_list, sample_order = NULL) {

  if (length(sobj_list) == 0) {
    stop("sobj_list must contain at least one Seurat object", call. = FALSE)
  }

  if (length(sobj_list) == 1) {
    # merge() needs a non-empty `y`, and 2:length() would give c(2, 1) here,
    # so a single object is only given the same barcode prefix merge() adds
    res <- sobj_list[[1]]

    if (!is.null(names(sobj_list))) {
      res <- RenameCells(res, add.cell.id = names(sobj_list)[1])
    }

  } else {
    res <- merge(
      x = sobj_list[[1]],
      y = sobj_list[-1],
      add.cell.ids = names(sobj_list)
    )
  }

  res <- res %>%
    ScaleData(assay = "RNA") %>%
    ScaleData(assay = "adt") %>%
    FindVariableFeatures(assay = "RNA")

  # Set sample order
  res@meta.data <- res@meta.data %>%
    rownames_to_column("cell_ids") %>%
    mutate(orig.ident = fct_relevel(orig.ident, sample_order)) %>%
    column_to_rownames("cell_ids")

  res
}

# Run PCA, cluster, and run UMAP using gene expression data
#
# sobj_in:    Seurat object
# assay:      assay used by RunPCA, FindNeighbors, and RunUMAP
# resolution: resolution passed to FindClusters
# dims:       principal components used for the neighbor graph and UMAP
# prefix:     string prepended to the PC_1/PC_2 metadata columns and to the
#             UMAP reduction name and key
# ...:        extra arguments passed to RunPCA
#
# Returns sobj_in with PCA and UMAP reductions added, the first two PCs
# copied into the meta.data, and the cluster assignments stored in the
# "<assay>_clusters" meta.data column.
cluster_RNA <- function(sobj_in, assay = "RNA", resolution = 0.6,
                        dims = 1:40, prefix = "", ...) {
  pc_cols <- c("PC_1", "PC_2")

  # PCA; by default only variable features are used. The first two PCs are
  # copied into the meta.data under (optionally prefixed) names.
  obj <- RunPCA(sobj_in, assay = assay, ...)
  obj <- AddMetaData(
    obj,
    metadata = FetchData(obj, pc_cols),
    col.name = str_c(prefix, pc_cols)
  )

  # FindNeighbors builds a K-nearest neighbors graph from euclidean distances
  # in PCA space and refines edge weights by the shared overlap of local
  # neighborhoods (Jaccard similarity). FindClusters then groups cells by
  # modularity optimization (Louvain by default, or SLM).
  obj <- FindNeighbors(obj, assay = assay, reduction = "pca", dims = dims)
  obj <- FindClusters(obj, resolution = resolution, verbose = FALSE)
  obj <- AddMetaData(
    obj,
    metadata = Idents(obj),
    col.name = str_c(assay, "_clusters")
  )

  # UMAP; the coordinates get added to the meta.data by clustifyr
  RunUMAP(
    obj,
    assay          = assay,
    dims           = dims,
    reduction.name = str_c(prefix, "umap"),
    reduction.key  = str_c(prefix, "UMAP_")
  )
}

# Fit a two-component gaussian mixture model for a given signal
#
# sobj_in:     object passed to FetchData (e.g. a Seurat object)
# data_column: name of the column to model
#
# Returns a list with
#   res:    data.frame (rownames = cell ids) containing the signal column,
#           the posterior probabilities "low" and "high", and GMM_grp
#           ("low" when the "low" posterior is > 0.5, otherwise "high")
#   mu, sigma, lambda: component means, standard deviations, and mixture
#           weights from normalmixEM, named "low"/"high"
# The component with the smaller mean is always labelled "low".
# NOTE(review): normalmixEM is called without starting values or a seed, so
# results may vary between runs — confirm whether a seed is set upstream.
fit_GMM <- function(sobj_in, data_column = "adt_ovalbumin") {
  signal_df <- FetchData(sobj_in, data_column)
  signal    <- signal_df[[data_column]]
  gmm       <- normalmixEM(signal)

  # Label the components so that "low" refers to the smaller mean
  grp_labels <- c("low", "high")

  if (gmm$mu[1] > gmm$mu[2]) {
    grp_labels <- c("high", "low")
  }

  post_cols        <- c("comp.1", "comp.2")
  names(post_cols) <- grp_labels

  gmm$mu     <- setNames(gmm$mu, grp_labels)
  gmm$sigma  <- setNames(gmm$sigma, grp_labels)
  gmm$lambda <- setNames(gmm$lambda, grp_labels)

  # Assign each cell to the component with posterior > 0.5
  posterior_df <- data.frame(
    cell_id = rownames(signal_df),
    data    = signal,
    gmm$posterior
  )

  res_df <- posterior_df %>%
    dplyr::rename(!!sym(data_column) := data) %>%
    rename(all_of(post_cols)) %>%
    mutate(GMM_grp = if_else(low > 0.5, "low", "high")) %>%
    column_to_rownames("cell_id")

  list(
    res    = res_df,
    mu     = gmm$mu,
    sigma  = gmm$sigma,
    lambda = gmm$lambda
  )
}

# Add distribution of GMM component to plot
#
# Returns a stat_function layer drawing lambda * dnorm(x, mu, sigma) for one
# mixture component, i.e. the normal density scaled by its mixture weight.
#
# gmm_in:  list with mu, sigma, and lambda vectors indexed by `key`
#          (e.g. the output of fit_GMM)
# cols_in: vector of line colors indexed by `key`
# key:     component to draw (e.g. "low" or "high")
add_stat_fun <- function(gmm_in, cols_in, key) {

  # mu: component mean; sigma: component SD; lam: mixture weight
  component_density <- function(x, mu, sigma, lam) {
    lam * dnorm(x, mu, sigma)
  }

  comp_args <- list(
    mu    = gmm_in$mu[key],
    sigma = gmm_in$sigma[key],
    lam   = gmm_in$lambda[key]
  )

  stat_function(
    fun   = component_density,
    args  = comp_args,
    geom  = "line",
    color = cols_in[key],
    lwd   = 1
  )
}

# Overlay feature data on UMAP or tSNE
# Cannot change number of columns when using FeaturePlot with split.by
#
# Builds a scatter plot of `x` vs `y` with points colored by `feature`.
#
# sobj_in:         Seurat object, or a data.frame/tibble that already holds
#                  the x, y, feature and (optional) split_id columns
# x, y:            columns for the x and y axes; a named string
#                  (c(new = "old")) renames the column to its name
# feature:         column used to color points; may be named to rename it
# pt_size:         point size
# split_id:        optional column used to split the plot into facets
# plot_cols:       low/high gradient colors for numeric features, or manual
#                  colors for discrete features
# feat_levels:     optional level order for a discrete feature
# split_levels:    optional facet order
# min_pct/max_pct: optional percentile-rank cutoffs (0-1), usable alone or
#                  together. Values ranked above max_pct are set to the
#                  smallest such value; values ranked below min_pct are set
#                  to the largest such value.
# calc_cor:        label each panel with the Pearson correlation of x and y
# lab_size:        size of the correlation label
# short_feat_name: drop a trailing "-A1234"-style suffix from the feature name
# lab_pos:         relative c(x, y) position of the correlation label
# lm_line:         add a linear regression line
# pt_outline:      size of black points drawn beneath the main points to act
#                  as an outline (NULL = no outline)
# show_title:      show the feature name as a facet-style title
# ...:             passed to facet_wrap() when split_id is given
#
# Returns a ggplot object.
plot_features <- function(sobj_in, x = "UMAP_1", y = "UMAP_2", feature, pt_size = 0.25,
                          split_id = NULL, plot_cols = c("#fafafa", "#e31a1c"),
                          feat_levels = NULL, split_levels = NULL, min_pct = NULL, 
                          max_pct = NULL, calc_cor = F, lab_size = 3.7, short_feat_name = T,
                          lab_pos = c(0.8, 0.9), lm_line = F, pt_outline = NULL, 
                          show_title = F, ...) {

  # Format input data
  counts     <- sobj_in
  short_feat <- feature

  if (short_feat_name) {
    short_feat <- feature %>%
      str_remove("\\-[A-Z][0-9]{4}$")
  }

  if (inherits(sobj_in, "Seurat")) {
    counts <- sobj_in %>%
      FetchData(vars = c(x, y, feature, split_id)) %>%
      as_tibble(rownames = "cell_ids")
  }

  counts <- counts %>%
    rename(!!sym(short_feat) := !!sym(feature))

  # Rename feature, x, and y when named strings were supplied
  if (!is.null(names(feature))) {
    names(short_feat) <- names(feature)

    counts <- counts %>%
      rename(!!!syms(short_feat))

    short_feat <- names(short_feat)
  }

  if (!is.null(names(x))) {
    counts <- counts %>%
      rename(!!!syms(x))

    x <- names(x)
  }

  if (!is.null(names(y))) {
    counts <- counts %>%
      rename(!!!syms(y))

    y <- names(y)
  }

  # Long format so the feature name can be shown as a facet strip
  if (show_title) {
    counts <- counts %>%
      pivot_longer(all_of(short_feat), names_to = "key", values_to = "value")

    short_feat <- "value"
  }

  # Clamp extreme feature values; thresholds come from the unclamped values
  if (!is.null(min_pct) || !is.null(max_pct)) {
    feat_vals <- counts[[short_feat]]
    pct_rank  <- percent_rank(feat_vals)
    clamped   <- feat_vals

    if (!is.null(max_pct)) {
      above <- feat_vals[which(pct_rank > max_pct)]

      if (length(above) > 0) {
        clamped <- pmin(clamped, min(above))
      }
    }

    if (!is.null(min_pct)) {
      below <- feat_vals[which(pct_rank < min_pct)]

      if (length(below) > 0) {
        clamped <- pmax(clamped, max(below))
      }
    }

    counts[[short_feat]] <- clamped
  }

  # Set feature order
  if (!is.null(feat_levels)) {
    counts <- counts %>%
      mutate(!!sym(short_feat) := fct_relevel(!!sym(short_feat), feat_levels))
  }

  # Set facet order; the split column is renamed to "split_id" from here on
  if (!is.null(split_id)) {
    counts <- counts %>%
      rename(split_id = !!sym(split_id))

    if (!is.null(split_levels)) {
      counts <- counts %>%
        mutate(split_id = fct_relevel(split_id, split_levels))
    }
  }

  # Calculate correlation (per facet when split_id is given)
  if (calc_cor) {
    if (!is.null(split_id)) {
      counts <- counts %>%
        group_by(split_id)
    }

    counts <- counts %>%
      mutate(
        cor_lab = cor(!!sym(x), !!sym(y)),
        cor_lab = round(cor_lab, digits = 2),
        cor_lab = str_c("r = ", cor_lab),
        min_x   = min(!!sym(x)),
        max_x   = max(!!sym(x)),
        min_y   = min(!!sym(y)),
        max_y   = max(!!sym(y)),
        lab_x   = (max_x - min_x) * lab_pos[1] + min_x,
        lab_y   = (max_y - min_y) * lab_pos[2] + min_y
      ) %>%
      ungroup()
  }

  # Create scatter plot; sorting by the feature draws high/late values on top
  res <- counts %>%
    arrange(!!sym(short_feat)) %>%
    ggplot(aes(!!sym(x), !!sym(y), color = !!sym(short_feat)))

  if (!is.null(pt_outline)) {
    # The outline layer's fill legend is only shown for discrete features
    pt_out_legend <- !is.numeric(counts[[short_feat]])

    res <- res +
      geom_point(
        aes(fill = !!sym(short_feat)),
        size        = pt_outline,
        color       = "black",
        show.legend = pt_out_legend
      )
  }

  res <- res +
    geom_point(size = pt_size)

  # Add regression line
  if (lm_line) {
    res <- res +
      geom_smooth(method = "lm", se = FALSE, color = "black", size = 0.5, linetype = 2)
  }

  # Add correlation coefficient label
  if (calc_cor) {
    res <- res +
      geom_text(
        aes(x = lab_x, y = lab_y, label = cor_lab),
        color         = "black",
        size          = lab_size,
        check_overlap = TRUE,
        show.legend   = FALSE
      )
  }

  # Show facet-style title
  if (show_title) {
    res <- res +
      facet_wrap(~ key, scales = "free") +
      theme(legend.title = element_blank())
  }

  # Set feature colors
  if (is.numeric(counts[[short_feat]])) {
    res <- res +
      scale_color_gradient(low = plot_cols[1], high = plot_cols[2])

  } else {
    res <- res +
      scale_color_manual(values = plot_cols)
  }

  # Split plot into facets
  if (!is.null(split_id)) {
    res <- res +
      facet_wrap(~ split_id, ...)
  }

  res
}

# Run gprofiler
#
# gene_list: genes to query with gprofiler2::gost
# genome:    genome build such as "GRCm38"; trailing digits are removed and
#            the prefix (GRCm, GRCh, BDGP) is mapped to a g:Profiler organism
# gmt_id:    custom g:Profiler GMT id; when given it is used as the organism
#            and `dbases` is ignored
# dbases:    data sources passed to gost as `sources`
# ...:       further arguments passed to gost (e.g. ordered_query)
#
# Returns a tibble of significant terms sorted by source and p-value.
# Returns an empty tibble when gost reports no significant terms.
# Returns NULL for an empty gene_list.
run_gprofiler <- function(gene_list, genome = NULL, gmt_id = NULL,
                          dbases = c("GO:BP", "GO:MF", "KEGG"), ...) {

  # Check for empty gene list
  if (is_empty(gene_list)) {
    return(NULL)
  }

  # Check arguments
  if (is.null(genome) && is.null(gmt_id)) {
    stop("ERROR: Must specifiy genome or gmt_id")
  }

  # Determine the organism; a custom GMT id takes precedence over the genome
  if (!is.null(gmt_id)) {
    org    <- gmt_id
    dbases <- NULL

  } else {
    genomes <- c(
      GRCm = "mmusculus",
      GRCh = "hsapiens",
      BDGP = "dmelanogaster"
    )

    genome_prefix <- str_remove(genome, "[0-9]+$")

    if (length(genome_prefix) != 1 || !genome_prefix %in% names(genomes)) {
      stop(
        "Unsupported genome '", str_c(genome, collapse = ", "),
        "'; expected one of: ", str_c(names(genomes), collapse = ", "),
        call. = FALSE
      )
    }

    org <- genomes[[genome_prefix]]
  }

  # Run gProfileR
  gost_res <- gost(
    query         = gene_list,
    organism      = org,
    sources       = dbases,
    domain_scope  = "annotated",
    significant   = TRUE,
    ...
  )

  # gost returns NULL when nothing is significant; an empty tibble keeps
  # callers such as dplyr::do() and nrow() checks working
  if (is.null(gost_res) || is.null(gost_res$result)) {
    return(tibble())
  }

  # Format and sort output data.frame
  gost_res$result %>%
    as_tibble() %>%
    arrange(source, p_value)
}

# Create GO bubble plot
#
# One facet per database (Biological Process, Cellular Component, Molecular
# Function, KEGG). Bubble height is -log10(p-value) and bubble size is the
# intersection size. The top `n_terms` terms per database are labelled.
#
# GO_df:       output of run_gprofiler (NULL or zero rows gives a blank plot)
# plot_colors: a single color for every bubble, or several colors that are
#              assigned to the databases in facet order (recycled if needed)
# n_terms:     number of labelled terms per database
#
# Returns a ggplot object.
create_bubbles <- function(GO_df, plot_colors = theme_cols[c(1:2, 4, 9)],
                           n_terms = 15) {

  # Check for empty inputs
  if (is_empty(GO_df) || nrow(GO_df) == 0) {
    res <- ggplot() +
      geom_blank()

    return(res)
  }

  # Shorten GO terms and database names
  GO_data <- GO_df %>%
    mutate(
      term_id = str_remove(term_id, "(GO|KEGG):"),
      term_id = str_c(term_id, " ", term_name),
      term_id = str_to_lower(term_id),
      term_id = str_trunc(term_id, 40, "right"),
      source  = fct_recode(
        source,
        "Biological\nProcess" = "GO:BP",
        "Cellular\nComponent" = "GO:CC",
        "Molecular\nFunction" = "GO:MF",
        "KEGG"                = "KEGG"
      )
    )

  # Reorder database names
  plot_levels <- c(
    "Biological\nProcess",
    "Cellular\nComponent",
    "Molecular\nFunction",
    "KEGG"
  )

  GO_data <- GO_data %>%
    mutate(source = fct_relevel(source, plot_levels))

  # Extract top terms for each database
  top_GO <- GO_data %>%
    group_by(source) %>%
    arrange(p_value) %>%
    dplyr::slice_head(n = n_terms) %>%
    ungroup()

  # A length-1 color is a fixed aesthetic; a longer vector cannot be one
  # (ggplot2 requires length 1 or nrow), so map it to the database instead
  if (length(plot_colors) == 1) {
    point_layers <- list(
      geom_point(color = plot_colors, alpha = 0.5, show.legend = TRUE)
    )

  } else {
    point_layers <- list(
      geom_point(aes(color = source), alpha = 0.5, show.legend = TRUE),
      scale_color_manual(
        values = rep_len(unname(plot_colors), nlevels(GO_data$source)),
        guide  = "none"
      )
    )
  }

  # Create bubble plots
  res <- GO_data %>%
    ggplot(aes(1.25, -log10(p_value), size = intersection_size)) +
    point_layers +
    geom_text_repel(
      aes(2, -log10(p_value), label = term_id),
      data         = top_GO,
      size         = 2.3,
      direction    = "y",
      hjust        = 0,
      segment.size = NA
    ) +
    xlim(1, 8) +
    labs(y = "-log10(p-value)") +
    theme_info +
    theme(
      axis.title.x    = element_blank(),
      axis.text.x     = element_blank(),
      axis.ticks.x    = element_blank()
    ) +
    facet_wrap(~ source, scales = "free", nrow = 1)

  res
}

# Plot percentage of cells in given groups
#
# Bar chart of the cells in each group, with bar segments filled by fill_id.
#
# sobj_in:     Seurat object; columns are read from its meta.data
# group_id:    meta.data column for the x-axis
# split_id:    optional meta.data column used for facets
# group_order: optional level order for the x-axis
# fill_id:     meta.data column used for the bar fill
# plot_colors: fill colors
# x_lab/y_lab: axis labels
# bar_pos:     position passed to geom_bar ("fill" shows fractions)
# order_count: order fill levels by their number of distinct cells
# bar_line:    width of the black bar outline
# ...:         passed to facet_wrap() when split_id is given
#
# Returns a ggplot object.
plot_cell_count <- function(sobj_in, group_id, split_id = NULL, group_order = NULL,
                            fill_id, plot_colors = theme_cols,
                            x_lab = "Cell type", y_lab = "Fraction of cells",
                            bar_pos = "fill", order_count = T, bar_line = 0, ...) {

  count_df <- rownames_to_column(sobj_in@meta.data, "cell_ids")

  # Copy the requested columns under fixed names used by the plot
  count_df$group_id <- count_df[[group_id]]
  count_df$fill_id  <- count_df[[fill_id]]

  if (!is.null(group_order)) {
    count_df$group_id <- fct_relevel(count_df$group_id, group_order)
  }

  if (!is.null(split_id)) {
    count_df$split_id <- count_df[[split_id]]
  }

  if (order_count) {
    count_df$fill_id <- fct_reorder(count_df$fill_id, count_df$cell_ids, n_distinct)
  }

  count_plot <- ggplot(count_df, aes(group_id, fill = fill_id)) +
    geom_bar(position = bar_pos, size = bar_line, color = "black") +
    scale_fill_manual(values = plot_colors) +
    labs(x = x_lab, y = y_lab) +
    theme_info

  if (is.null(split_id)) {
    return(count_plot)
  }

  count_plot +
    facet_wrap(~ split_id, ...)
}

# Plot confidence intervals for median
#
# For each group in group_column, the median of data_column is bootstrapped
# (10,000 resamples). Basic bootstrap intervals are drawn at the 90%, 95%,
# and 99% levels on top of a grey violin of the raw values. The x-axis uses
# a log10 scale.
#
# input_sobj:   Seurat object
# group_column: column defining the groups (y-axis and color)
# data_column:  column whose median is summarized
# box_cols:     colors for the groups, passed to scale_color_manual()
#
# Returns a ggplot object.
create_ci_boxes <- function(input_sobj, group_column, data_column, box_cols) {

  # Bootstrap the median of data_in; one row per confidence level
  get_boots <- function(data_in, conf = c(0.9, 0.95, 0.99), ...) {

    # Basic bootstrap interval for a single confidence level
    get_ci <- function(conf, boot_in, ...) {
      ci <- boot.ci(
        boot.out = boot_in,
        conf     = conf,
        type     = "basic",
        ...
      )

      # Columns 4 and 5 of ci$basic are the interval endpoints
      tibble(
        median = ci$t0,
        conf   = str_c(conf * 100, "%"),
        lower  = ci$basic[4],
        upper  = ci$basic[5]
      )
    }

    boot_obj <- boot(
      data      = data_in,
      statistic = function(x, i) median(x[i]),
      R         = 10000
    )

    names(conf) <- conf

    conf %>%
      map(get_ci, boot_obj) %>%
      bind_rows()
  }

  box_data <- input_sobj %>%
    FetchData(c(group_column, data_column)) %>%
    as_tibble(rownames = "cell_id")

  conf_df <- box_data %>%
    group_by(!!sym(group_column)) %>%
    summarize(boot_res = list(get_boots(!!sym(data_column)))) %>%
    unnest(cols = boot_res)

  # Wider confidence levels are drawn thinner and fainter
  conf_sizes <- c(
    `90%` = 4,
    `95%` = 3,
    `99%` = 2
  )

  conf_alphas <- c(
    `90%` = 1,
    `95%` = 0.5,
    `99%` = 0.25
  )

  res <- conf_df %>%
    ggplot(aes(`median`, !!sym(group_column), color = !!sym(group_column))) +
    geom_violin(
      data  = box_data,
      aes(!!sym(data_column), !!sym(group_column)),
      fill  = "#f0f0f0",
      color = "#f0f0f0",
      size  = 0.2
    ) +
    geom_errorbarh(aes(xmin = lower, xmax = upper, alpha = conf, size = conf), height = 0) +
    geom_point(shape = 22, size = 1, fill = "white") +
    scale_color_manual(values = box_cols, guide = "none") +
    scale_alpha_manual(values = conf_alphas, guide = "none") +
    scale_size_manual(
      name   = "Confidence Level",
      values = conf_sizes,
      guide  = guide_legend(direction = "horizontal", title.position = "top", label.position = "bottom")
    ) +
    scale_x_log10() +
    labs(x = data_column) +
    theme_info +
    theme(
      legend.position = "top",
      legend.title    = element_text(size = 10),
      legend.text     = element_text(size = 10),
      axis.title.y    = element_blank()
    )

  res
}

# Run FindAllMarkers and keep markers with an adjusted p-value below p_cutoff
#
# input_sobj: Seurat object whose Idents define the groups to compare
# only_pos:   passed to FindAllMarkers as only.pos
# p_cutoff:   markers with p_val_adj >= p_cutoff are removed
# ...:        passed to FindAllMarkers
#
# Returns a tibble of markers.
# NOTE(review): FindAllMarkers appears to return a data.frame without any
# columns when no markers are found. That result is returned unfiltered (as
# an empty tibble) rather than failing on the missing p_val_adj column.
find_markers <- function(input_sobj, only_pos = T, p_cutoff = 0.05, ...) {
  res <- input_sobj %>%
    FindAllMarkers(only.pos = only_pos, ...) %>%
    as_tibble()

  if (!"p_val_adj" %in% colnames(res)) {
    return(res)
  }

  res %>%
    filter(p_val_adj < p_cutoff)
}

# Find cluster markers for each separate cell type
#
# Subsets input_sobj to cells whose grp_column equals input_grp, then runs
# find_markers comparing the clusters in clust_column.
#
# input_grp:    value of grp_column to keep
# input_sobj:   Seurat object
# grp_column:   metadata column defining the cell groups
# clust_column: metadata column defining the clusters within a group
#
# Returns the marker tibble with an added cell_type column, or NULL when the
# group contains fewer than two clusters.
find_group_markers <- function(input_grp, input_sobj, grp_column, clust_column) {

  grp_sobj <- subset(input_sobj, !!sym(grp_column) == input_grp)

  # Markers need at least two clusters to compare
  if (n_distinct(grp_sobj@meta.data[, clust_column]) < 2) {
    return(NULL)
  }

  Idents(grp_sobj) <- FetchData(grp_sobj, clust_column)

  grp_sobj %>%
    find_markers() %>%
    mutate(cell_type = input_grp)
}

# Create reference UMAP for comparisons
#
# Wraps plot_features with small outlined points, a custom color legend at
# the top, and blank axes.
#
# input_sobj:  Seurat object or data.frame accepted by plot_features
# pt_mtplyr:   multiplier applied to the base point size of 0.1
# color_guide: guide used for the color legend
# ...:         passed to plot_features (e.g. feature, plot_cols, feat_levels)
#
# Returns a ggplot object.
create_ref_umap <- function(input_sobj, pt_mtplyr = 1, color_guide, ...) {
  ref_plot <- plot_features(
    input_sobj,
    pt_size    = 0.1 * pt_mtplyr,
    pt_outline = 0.4,
    ...
  )

  ref_plot +
    guides(color = color_guide) +
    blank_theme +
    theme(
      legend.position = "top",
      legend.title    = element_blank(),
      legend.text     = element_text(size = 10)
    )
}

# Create UMAPs showing marker gene signal
#
# input_sobj:    Seurat object
# input_markers: genes to plot, one UMAP per gene
# umap_col:      color for high signal (low signal is "#fafafa")
# add_outline:   outline point size passed to plot_features (NULL = none)
# pt_mtplyr:     multiplier applied to the base point size of 0.25
#
# Returns a list of ggplot objects, one per marker, titled with the gene name.
create_marker_umaps <- function(input_sobj, input_markers, umap_col, add_outline = NULL, pt_mtplyr = 1) {

  pt_size <- 0.25 * pt_mtplyr

  # Shared styling; axis text is kept but colored white, presumably so panels
  # line up with plots that show axis labels
  marker_theme <- theme(
    plot.title        = element_text(size = 13),
    legend.position   = "bottom",
    legend.title      = element_blank(),
    legend.text       = element_text(size = 8),
    legend.key.height = unit(0.1, "cm"),
    legend.key.width  = unit(0.3, "cm"),
    axis.title.y      = element_text(size = 13, color = "white"),
    axis.text.y       = element_text(size = 8, color = "white")
  )

  make_umap <- function(gene) {
    gene_umap <- plot_features(
      input_sobj,
      feature    = gene,
      plot_cols  = c("#fafafa", umap_col),
      pt_outline = add_outline,
      pt_size    = pt_size
    )

    gene_umap +
      ggtitle(gene) +
      blank_theme +
      marker_theme
  }

  map(input_markers, make_umap)
}

# Create boxplots showing marker gene signal
#
# One facet per marker gene, with one box/violin/point cloud per cluster.
#
# input_sobj:     Seurat object
# input_markers:  genes to plot; facet labels are truncated to 9 characters
# clust_column:   metadata column with cluster labels. The label with its
#                 leading "<letters/digits/_>-" prefix removed is the group.
# box_cols:       named cluster colors; their names also set the cluster order
# group:          optional group; only clusters from this group are kept
# include_legend: show the cluster legend at the top
# all_boxes:      always draw boxplots (also used when there are > 6 clusters)
# all_violins:    draw violins instead of quasirandom points
# order_boxes:    order clusters by mean signal (descending)
# n_boxes:        expected number of facets; with n_rows sets the facet columns
# n_rows:         number of facet rows
# pt_mtplyr:      multiplier for the quasirandom point size
# ...:            passed to theme()
#
# Returns a ggplot object, or a cowplot grid when the plot is padded because
# it has few facets.
create_marker_boxes <- function(input_sobj, input_markers, clust_column, box_cols,
                                group = NULL, include_legend = F, all_boxes = F,
                                all_violins = F, order_boxes = T, n_boxes = 10,
                                n_rows = 2, pt_mtplyr = 1, ...) {

  # Retrieve and format data for boxplots
  box_data <- input_sobj %>%
    FetchData(c(clust_column, input_markers)) %>%
    as_tibble(rownames = "cell_id") %>%
    mutate(grp = str_remove(!!sym(clust_column), "^[a-zA-Z0-9_]+-"))

  # Facet labels are truncated. unique() stops fct_relevel() from failing
  # when two markers truncate to the same label.
  input_markers <- input_markers %>%
    str_trunc(9) %>%
    unique()

  # Filter based on input group; !! avoids matching a data column named "group"
  if (!is.null(group)) {
    box_data <- box_data %>%
      filter(grp == !!group)
  }

  # Format data for plots
  box_data <- box_data %>%
    pivot_longer(cols = c(-cell_id, -grp, -!!sym(clust_column)), names_to = "key", values_to = "Counts") %>%
    mutate(
      !!sym(clust_column) := fct_relevel(!!sym(clust_column), names(box_cols)),
      key = str_trunc(key, width = 9, side = "right"),
      key = fct_relevel(key, input_markers)
    )

  # Order boxes by mean signal
  if (order_boxes) {
    box_data <- box_data %>%
      mutate(!!sym(clust_column) := fct_reorder(!!sym(clust_column), Counts, mean, .desc = TRUE))
  }

  n_clust <- box_data %>%
    pull(clust_column) %>%
    n_distinct()

  # Create plots
  n_cols <- ceiling(n_boxes / n_rows)

  res <- box_data %>%
    ggplot(aes(!!sym(clust_column), Counts, color = !!sym(clust_column))) + 
    facet_wrap(~ key, ncol = n_cols) +
    scale_color_manual(values = box_cols) +
    theme_info +
    theme(
      panel.spacing.x  = unit(0.7, "cm"),
      strip.background = element_blank(),
      strip.text       = element_text(size = 13),
      legend.position  = "none",
      axis.title.x     = element_blank(),
      axis.title.y     = element_text(size = 13),
      axis.text.x      = element_blank(),
      axis.text.y      = element_text(size = 8),
      axis.ticks.x     = element_blank(),
      axis.line.x      = element_blank()
    )

  # Adjust output plot type; the color scale above is shared by all branches
  if (n_clust > 6 || all_boxes) {
    res <- res +
      geom_boxplot(aes(fill = !!sym(clust_column), color = !!sym(clust_column)), size = 0, outlier.color = "#f0f0f0", outlier.size = 0.25) +
      stat_summary(fun = "median", geom = "point", shape = 22, size = 1, fill = "white") +
      scale_fill_manual(values = box_cols)

  } else if (all_violins) {
    res <- res +
      geom_violin(aes(fill = !!sym(clust_column)), size = 0.2) +
      stat_summary(fun = "median", geom = "point", shape = 22, size = 1, fill = "white") +
      scale_fill_manual(values = box_cols)

  } else {
    pt_size <- 0.3 * pt_mtplyr

    res <- res +
      geom_quasirandom(size = pt_size)
  }

  res <- res +
    theme(...)

  # Add legend
  if (include_legend) {
    res <- res +
      guides(color = col_guide) +
      theme(legend.position = "top")
  }

  # Pad plots with few facets by placing them in the first cell of a
  # two-row grid. A single facet is treated as two, and the grid keeps at
  # least one column.
  n_keys <- n_distinct(box_data$key)

  if (n_keys <= n_cols && n_rows > 1) {
    n_keys <- max(n_keys, 2)
    n_cols <- max(floor(n_cols / n_keys), 1)

    res <- res %>%
      plot_grid(
        ncol = n_cols,
        nrow = 2
      )
  }

  res
}

# Create figure summarizing marker genes
#
# Stacks three rows:
#   1. the reference UMAP and UMAPs of up to 7 top markers, with the
#      reference legend above them
#   2. marker boxplots
#   3. GO term bubbles
# Rows stay blank when input_markers (or input_GO) has no rows.
#
# input_sobj:    Seurat object
# input_markers: marker table with a `gene` column; its first n_boxes genes
#                are plotted
# input_GO:      GO table (run_gprofiler output) for the same cluster
# clust_column:  metadata column with cluster labels
# input_umap:    reference UMAP ggplot
# umap_color:    color for the marker UMAPs and GO bubbles
# fig_heights:   relative heights of the three rows
# GO_genome:     not used in this function
# box_colors:    named cluster colors for the boxplots
# n_boxes:       number of markers shown in the boxplots
# umap_outline:  outline size for marker UMAP points (NULL = none)
# umap_mtplyr:   point-size multiplier for the marker UMAPs
# xlsx_name:     when given, markers and GO terms are appended to
#                "<xlsx_name>_markers.xlsx" and "<xlsx_name>_GO.xlsx"
# sheet_name:    sheet name used in those workbooks
# ...:           passed to create_marker_boxes()
#
# Returns a cowplot grid.
create_marker_fig <- function(input_sobj, input_markers, input_GO, clust_column, 
                              input_umap, umap_color, fig_heights = c(0.46, 0.3, 0.3), 
                              GO_genome = params$genome, box_colors, n_boxes = 10,
                              umap_outline = NULL, umap_mtplyr = 1, xlsx_name = NULL, 
                              sheet_name = NULL, ...) {

  blank_umap <- ggplot() +
    geom_blank() +
    theme_void()

  marks_umap  <- blank_umap
  marks_boxes <- blank_umap
  GO_bubbles  <- blank_umap

  n_markers <- nrow(input_markers)

  # Create UMAPs showing marker gene signal
  if (n_markers > 0) {
    top_marks <- input_markers$gene %>%
      head(n_boxes)

    # Move the reference legend above the whole UMAP grid
    clust_legend <- get_legend(input_umap)

    input_umap <- input_umap +
      theme(legend.position = "none")

    marks_umap <- input_sobj %>%
      create_marker_umaps(
        input_markers = head(top_marks, 7),
        umap_col      = umap_color,
        add_outline   = umap_outline,
        pt_mtplyr     = umap_mtplyr
      ) %>%
      append(list(input_umap), .)

    marks_umap <- plot_grid(
      plotlist = marks_umap,
      ncol     = 4,
      nrow     = 2,
      align    = "vh",
      axis     = "trbl"
    )

    marks_umap <- plot_grid(
      clust_legend, marks_umap,
      rel_heights = c(0.2, 0.9),
      nrow = 2
    )

    # Create boxplots showing marker gene signal
    marks_boxes <- input_sobj %>%
      create_marker_boxes(
        input_markers = top_marks,
        clust_column  = clust_column,
        box_cols      = box_colors,
        n_boxes       = n_boxes,
        plot.margin   = unit(c(0.8, 0.2, 0.2, 0.2), "cm"),
        ...
      )

    # Create GO term plots
    if (nrow(input_GO) > 0) {
      GO_bubbles <- input_GO %>%
        create_bubbles(plot_colors = umap_color) +
        theme(
          plot.margin      = unit(c(0.8, 0.2, 0.2, 0.2), "cm"),
          strip.background = element_blank(),
          strip.text       = element_text(size = 13),
          axis.title.y     = element_text(size = 13),
          axis.text.y      = element_text(size = 8),
          axis.line.x      = element_blank(),
          legend.position  = "bottom",
          legend.title     = element_blank(),
          legend.text      = element_text(size = 8)
        )

      # Write GO terms to excel file 
      if (!is.null(xlsx_name)) {
        input_GO %>%
          dplyr::select(
            term_name,  term_id,
            source,     effective_domain_size,
            query_size, intersection_size,
            p_value,    significant 
          ) %>%
          arrange(source, p_value) %>%
          write.xlsx(
            file      = str_c(xlsx_name, "_GO.xlsx"),
            sheetName = sheet_name,
            append    = TRUE
          )
      }
    }

    # Write markers to excel file
    if (!is.null(xlsx_name)) {
      input_markers %>%
        write.xlsx(
          file      = str_c(xlsx_name, "_markers.xlsx"),
          sheetName = sheet_name,
          append    = TRUE
        )
    }
  }

  # Create final figure. The rows are axis-aligned only when the boxplot row
  # holds all n_boxes markers. Only the grid that is returned is built.
  if (n_markers < n_boxes) {
    res <- plot_grid(
      marks_umap, marks_boxes, GO_bubbles,
      rel_heights = fig_heights,
      ncol        = 1
    )

  } else {
    res <- plot_grid(
      marks_umap, marks_boxes, GO_bubbles,
      rel_heights = fig_heights,
      ncol        = 1,
      align       = "v",
      axis        = "rl"
    )
  }

  res
}

# Filter clusters and set cluster order
#
# Keeps the clusters that have at least n_cutoff rows in input_marks and
# orders them as they appear in names(input_cols).
#
# input_cols:  named vector (e.g. cluster colors) defining the order
# input_marks: marker table with a `cluster` column
# n_cutoff:    minimum number of rows (markers) a cluster needs
#
# Returns a character vector of cluster names.
set_cluster_order <- function(input_cols, input_marks, n_cutoff = 5) {
  marker_counts <- table(as.character(input_marks$cluster), useNA = "ifany")
  kept_clusters <- names(marker_counts)[marker_counts >= n_cutoff]

  ordered_clusters <- names(input_cols)
  ordered_clusters[ordered_clusters %in% kept_clusters]
}
  
# Create v1 panel for marker genes
#
# Finds markers for every cluster in clust_column and runs gprofiler on each
# cluster's markers. For every cluster with enough markers it prints a
# markdown header and a marker figure (create_marker_fig).
#
# input_sobj:   Seurat object
# input_cols:   named cluster colors; their order sets the panel order
# input_umap:   optional reference UMAP; when NULL one is drawn per cluster
#               with that cluster's level last
# clust_column: metadata column with cluster labels
# order_boxes:  passed to create_marker_fig
# color_guide:  legend guide for the reference UMAP
# uniq_GO:      keep only GO terms found for a single cluster
# umap_mtplyr:  point-size multiplier used when there are < 500 cells
# xlsx_name:    passed to create_marker_fig for Excel output
# ...:          passed to create_marker_fig
#
# Called for its printed output; returns NULL invisibly.
create_marker_panel_v1 <- function(input_sobj, input_cols, input_umap = NULL, clust_column, order_boxes = T,
                                   color_guide = guide_legend(override.aes = list(size = 3.5, shape = 16)),
                                   uniq_GO = F, umap_mtplyr = 6, xlsx_name = NULL, ...) {

  n_cells <- ncol(input_sobj)

  # Enlarged points only for small datasets; the reference UMAP gets 2.5x more
  umap_mtplyr <- if_else(n_cells < 500, umap_mtplyr, 1)
  ref_mtplyr  <- if_else(umap_mtplyr == 1, umap_mtplyr, umap_mtplyr * 2.5)

  # Marker genes for every cluster in clust_column
  Idents(input_sobj) <- FetchData(input_sobj, clust_column)
  markers <- find_markers(input_sobj)

  # GO terms per cluster, querying marker genes ordered by adjusted p-value
  GO_df <- markers %>%
    group_by(cluster) %>%
    do({
      arrange(., p_val_adj) %>%
        pull(gene) %>%
        run_gprofiler(
          genome = params$genome,
          ordered_query = TRUE
        )
    }) %>%
    ungroup()

  if (uniq_GO && nrow(GO_df) > 0) {
    # Drop terms shared by several clusters
    GO_df <- GO_df %>%
      group_by(term_id) %>%
      filter(n() == 1) %>%
      ungroup()
  }

  fig_clusters <- set_cluster_order(input_cols, markers)

  for (clust in fig_clusters) {
    cat("\n#### ", clust, "\n", sep = "")

    fig_marks <- filter(markers, cluster == clust)
    fig_GO    <- filter(GO_df, cluster == clust)
    umap_col  <- input_cols[clust]

    # Without a supplied UMAP, move the current cluster's level to the end;
    # plot_features sorts points by the feature, so it is drawn last
    ref_umap <- input_umap

    if (is.null(ref_umap)) {
      umap_levels <- c(names(input_cols)[names(input_cols) != clust], clust)

      ref_umap <- create_ref_umap(
        input_sobj,
        feature     = clust_column,
        plot_cols   = input_cols,
        feat_levels = umap_levels,
        pt_mtplyr   = ref_mtplyr,
        color_guide = color_guide
      )
    }

    marker_fig <- create_marker_fig(
      input_sobj,
      input_markers = fig_marks,
      input_GO      = fig_GO,
      clust_column  = clust_column,
      input_umap    = ref_umap,
      umap_color    = umap_col,
      box_colors    = input_cols,
      order_boxes   = order_boxes,
      umap_mtplyr   = umap_mtplyr,
      xlsx_name     = xlsx_name,
      sheet_name    = clust,
      ...
    )

    cat(nrow(fig_marks), "marker genes were identified,", nrow(fig_GO), "GO terms were identified.")
    print(marker_fig)
    cat("\n\n---\n\n<br>\n\n<br>\n\n")
  }
}

# Create v2 panel that splits plots into groups
#
# For every cluster with enough markers, prints a markdown header and a
# marker figure (create_marker_fig). Cluster labels are expected to look like
# "<prefix>-<group>"; the group is the label with its leading
# "<letters/digits/_>-" prefix removed. Each reference UMAP colors only the
# clusters of the current group and shows every other cell as "Other".
#
# input_sobj:    Seurat object with UMAP_1/UMAP_2 coordinates
# input_markers: marker table with `cluster`, `gene`, and `p_val_adj` columns
# input_cols:    named cluster colors; their order sets the panel order
# grp_column:    metadata column with the group of each cell
# clust_column:  metadata column with cluster labels
# color_guide:   legend guide for the reference UMAP
# uniq_GO:       keep only GO terms found for a single cluster
# umap_mtplyr:   point-size multiplier used when there are < 500 cells
# xlsx_name:     passed to create_marker_fig for Excel output
# ...:           passed to create_marker_fig
#
# Called for its printed output; returns NULL invisibly.
create_marker_panel_v2 <- function(input_sobj, input_markers, input_cols, grp_column, clust_column, 
                                   color_guide = guide_legend(override.aes = list(size = 3.5, shape = 16)), 
                                   uniq_GO = F, umap_mtplyr = 6, xlsx_name = NULL, ...) {

  # Set point size; enlarged points only for small datasets
  umap_mtplyr <- if_else(ncol(input_sobj) < 500, umap_mtplyr, 1)
  ref_mtplyr  <- if_else(ncol(input_sobj) < 500, umap_mtplyr * 2.5, 1)

  # Figure colors and order
  fig_clusters <- input_cols %>%
    set_cluster_order(input_markers)

  # Find GO terms per cluster, querying genes ordered by adjusted p-value
  GO_df <- input_markers %>%
    group_by(cluster) %>%
    do({
      arrange(., p_val_adj) %>%
        pull(gene) %>%
        run_gprofiler(
          genome = params$genome,
          ordered_query = TRUE
        )
    }) %>%
    ungroup()

  if (uniq_GO && nrow(GO_df) > 0) {
    GO_df <- GO_df %>%
      group_by(term_id) %>%
      filter(n() == 1) %>%
      ungroup()
  }

  # Create figures
  for (clust in fig_clusters) {
    cat("\n#### ", clust, "\n", sep = "")

    # Filter markers and GO terms
    fig_marks <- input_markers %>%
      filter(cluster == clust)

    fig_GO <- GO_df %>%
      filter(cluster == clust)

    # Set colors
    umap_col <- input_cols[clust]

    group <- clust %>%
      str_remove("^[a-zA-Z0-9_]+-")

    # Colors of the clusters in the same group. The suffix is matched
    # literally so characters such as "+", ".", or "(" in group names are
    # not treated as regex syntax.
    grp_suffix <- str_c("-", group)
    fig_cols   <- input_cols[which(endsWith(names(input_cols), grp_suffix))]
    fig_cols   <- c("Other" = "#fafafa", fig_cols)

    # Create reference UMAP. as.character() lets if_else() combine "Other"
    # with factor cluster columns. !!group injects the local value even if
    # the data has a column named "group".
    ref_umap <- input_sobj %>%
      FetchData(c("UMAP_1", "UMAP_2", grp_column, clust_column)) %>%
      as_tibble(rownames = "cell_id") %>%
      mutate(!!sym(clust_column) := if_else(
        !!sym(grp_column) != !!group,
        "Other",
        as.character(!!sym(clust_column))
      )) %>%
      create_ref_umap(
        feature     = clust_column,
        plot_cols   = fig_cols,
        feat_levels = names(fig_cols),
        pt_mtplyr   = ref_mtplyr,
        color_guide = color_guide
      )

    # Create panel
    marker_fig <- input_sobj %>%
      create_marker_fig(
        input_markers = fig_marks,
        input_GO      = fig_GO,
        clust_column  = clust_column,
        input_umap    = ref_umap,
        umap_color    = umap_col,
        box_colors    = fig_cols,
        group         = group,
        umap_mtplyr   = umap_mtplyr,
        xlsx_name     = xlsx_name,
        sheet_name    = clust,
        ...
      )

    cat(nrow(fig_marks), "marker genes were identified.", nrow(fig_GO), "GO terms were identified.")
    print(marker_fig)
    cat("\n\n---\n\n<br>\n\n<br>\n\n")
  }
}

# Create panels for manuscript
#
# Finds markers for every cluster in clust_column and runs gprofiler on each
# cluster's markers. For every cluster with enough markers it prints a
# markdown header and a marker figure (create_marker_fig), optionally placed
# below summary_fig.
#
# input_sobj:   Seurat object
# input_cols:   named cluster colors; their order sets the panel order
# summary_fig:  optional plot stacked above every marker figure
# input_umap:   optional reference UMAP; when NULL one is drawn per cluster
#               with that cluster's level last
# clust_column: metadata column with cluster labels
# color_guide:  legend guide for the reference UMAP
# order_boxes:  passed to create_marker_fig
# uniq_GO:      keep only GO terms found for a single cluster
# umap_mtplyr:  point-size multiplier used when there are < 500 cells
# xlsx_name:    passed to create_marker_fig for Excel output
# ...:          passed to create_marker_fig
#
# Called for its printed output; returns NULL invisibly.
create_paper_figures <- function(input_sobj, input_cols, summary_fig = NULL, input_umap = NULL, clust_column,
                                 color_guide = guide_legend(override.aes = list(size = 3.5, shape = 16)),
                                 order_boxes = T, uniq_GO = F, umap_mtplyr = 6, xlsx_name = NULL, ...) {

  n_cells <- ncol(input_sobj)

  # Enlarged points only for small datasets; the reference UMAP gets 2.5x more
  umap_mtplyr <- if_else(n_cells < 500, umap_mtplyr, 1)
  ref_mtplyr  <- if_else(umap_mtplyr == 1, umap_mtplyr, umap_mtplyr * 2.5)

  # Marker genes for every cluster in clust_column
  Idents(input_sobj) <- FetchData(input_sobj, clust_column)
  markers <- find_markers(input_sobj)

  # GO terms per cluster, querying marker genes ordered by adjusted p-value
  GO_df <- markers %>%
    group_by(cluster) %>%
    do({
      arrange(., p_val_adj) %>%
        pull(gene) %>%
        run_gprofiler(
          genome = params$genome,
          ordered_query = TRUE
        )
    }) %>%
    ungroup()

  if (uniq_GO && nrow(GO_df) > 0) {
    # Drop terms shared by several clusters
    GO_df <- GO_df %>%
      group_by(term_id) %>%
      filter(n() == 1) %>%
      ungroup()
  }

  fig_clusters <- set_cluster_order(input_cols, markers)

  for (clust in fig_clusters) {
    cat("\n#### ", clust, "\n", sep = "")

    fig_marks <- filter(markers, cluster == clust)
    fig_GO    <- filter(GO_df, cluster == clust)
    umap_col  <- input_cols[clust]

    # Without a supplied UMAP, move the current cluster's level to the end;
    # plot_features sorts points by the feature, so it is drawn last
    ref_umap <- input_umap

    if (is.null(ref_umap)) {
      umap_levels <- c(names(input_cols)[names(input_cols) != clust], clust)

      ref_umap <- create_ref_umap(
        input_sobj,
        feature     = clust_column,
        plot_cols   = input_cols,
        feat_levels = umap_levels,
        pt_mtplyr   = ref_mtplyr,
        color_guide = color_guide
      )
    }

    marker_fig <- create_marker_fig(
      input_sobj,
      input_markers = fig_marks,
      input_GO      = fig_GO,
      clust_column  = clust_column,
      input_umap    = ref_umap,
      umap_color    = umap_col,
      box_colors    = input_cols,
      order_boxes   = order_boxes,
      umap_mtplyr   = umap_mtplyr,
      xlsx_name     = xlsx_name,
      sheet_name    = clust,
      ...
    )

    # Optionally stack the summary figure on top
    if (!is.null(summary_fig)) {
      marker_fig <- plot_grid(
        summary_fig, marker_fig,
        rel_heights = c(0.3, 0.7),
        ncol        = 1,
        align       = "vh",
        axis        = "trbl"
      )
    }

    print(marker_fig)
    cat("\n\n---\n\n<br>\n\n<br>\n\n")
  }
}


# Default chunk options: hide messages and warnings in the knitted output
knitr::opts_chunk$set(message = FALSE, warning = FALSE)

# Load packages
# library() (rather than require()) makes a missing package stop the script
# immediately instead of failing later with an obscure "function not found".
R_packages <- c(
  "tidyverse",  "Seurat",
  "gprofiler2", "knitr",
  "cowplot",    "ggbeeswarm",
  "ggrepel",    "RColorBrewer",
  "xlsx",       "colorblindr",
  "ggforce",    "broom",
  "mixtools",   "clustifyr",
  "boot"
)

for (package in R_packages) {
  library(package, character.only = TRUE)
}


# ggplot2 themes
# theme_info: cowplot base theme with plain-weight titles and light-grey
# facet strips. Each tweak is layered as its own theme() so the three
# elements read independently.
theme_info <- theme_cowplot() +
  theme(plot.title = element_text(face = "plain", size = 20)) +
  theme(strip.background = element_rect(fill = "#fafafa")) +
  theme(strip.text = element_text(face = "plain"))

# umap_theme: theme_info with axis tick labels and tick marks removed
umap_theme <- theme_info +
  theme(axis.text = element_blank()) +
  theme(axis.ticks = element_blank())

# blank_theme: umap_theme with axis lines and axis titles also removed
blank_theme <- umap_theme +
  theme(axis.line = element_blank()) +
  theme(axis.title = element_blank())

# Legend guides
# Both guides enlarge the legend keys to size 3.5. col_guide draws solid
# points (shape 16); outline_guide draws fillable points (shape 21) with a
# thin black border.
col_guide <- guide_legend(
  override.aes = list(size = 3.5, shape = 16)
)

outline_guide <- guide_legend(
  override.aes = list(size = 3.5, shape = 21, color = "black", stroke = 0.25)
)

# Base color palettes
base_cols <- c(
  "#225ea8",  # blue
  "#e31a1c",  # red
  "#238443",  # green
  "#ec7014",  # orange
  "#6a51a3",  # purple
  "#c51b7d",  # pink
  "#8c510a",  # brown
  "#217D87",  # teal, darken("#41b6c4", 0.3)
  "#F0E442",  # yellow, palette_OkabeIto[4]
  "#000000"   # black
)

# One pair per base color: c(lightened + desaturated tint, original color),
# keyed by the original hex code
base_cols_paired <- setNames(
  lapply(
    base_cols,
    function(col) c(desaturate(lighten(col, 0.25), 0.2), col)
  ),
  base_cols
)

# Full palette: the originals followed by their tints, in the same order
base_cols <- c(base_cols, desaturate(lighten(base_cols, 0.25), 0.2))

# Okabe Ito color palettes
# Okabe-Ito colors with a red and a purple spliced in
ito_cols <- c(
  palette_OkabeIto[1:4], "#d7301f",
  palette_OkabeIto[5:6], "#6a51a3",
  palette_OkabeIto[7:8]
)

# One pair per color: c(original color, 30% darker shade), keyed by the
# original hex code
ito_cols_paired <- setNames(
  lapply(ito_cols, function(col) c(col, darken(col, 0.3))),
  ito_cols
)

# Full palette: originals, then 40% darker shades, then black
ito_cols <- c(ito_cols, darken(ito_cols, 0.4), "#000000")

# Set color palette
# The Okabe-Ito palette is active. To use the base palette instead, assign
# theme_cols <- base_cols and paired_cols <- base_cols_paired here.
theme_cols  <- ito_cols
paired_cols <- ito_cols_paired
# Load Seurat objects
# Each element of rds_files is read with read_rds(); its name is split on
# "_": the first part becomes orig.ident, and "cell_type" + the second part
# names the metadata column holding the cell type label. The "type" column
# is then rebuilt as "<first part>_<label with spaces replaced by _>".
sobjs <- rds_files %>%
  imap(~ {
    name_parts <- str_split(.y, pattern = "_")[[1]]

    # Without a second part the type column would be "cell_typeNA"
    if (length(name_parts) < 2) {
      stop(
        "Expected an rds_files name of the form '<sample>_<suffix>', got: ",
        .y,
        call. = FALSE
      )
    }

    tm        <- name_parts[1]
    type_colm <- str_c("cell_type", name_parts[2])
    res       <- read_rds(.x)

    res@meta.data <- res@meta.data %>%
      rownames_to_column("cell_id") %>%
      mutate(
        orig.ident = tm,
        type       = !!sym(type_colm),
        # Replace every space, not only the first one
        type       = str_replace_all(type, " ", "_"),
        type       = str_c(tm, "_", type)
      ) %>%
      column_to_rownames("cell_id")

    res
  })

# Create Clustifyr reference
# Merge our samples (no explicit sample order) and summarize expression by
# the "type" metadata column into a clustifyr reference matrix
subtype_ref <- seurat_ref(merge_sobj(sobjs), cluster_col = "type")

# Load external data
# The CSV's "X" column supplies the row names of the count table before it
# is wrapped in a Seurat object
sobj_ext <- CreateSeuratObject(
  column_to_rownames(read.csv(ext_files), var = "X")
)

# ext_meta <- read_tsv("../ext_data/GSE137710_mouse_spleen_cell_metadata_4464x9.tsv.gz") %>%
#   group_by(cell_ID) %>%
#   filter(n() == 1) %>%
#   column_to_rownames("cell_ID")
# 
# sobj_ext <- ext_files %>%
#   read.table() %>%
#   t() %>%
#   CreateSeuratObject(meta.data = ext_meta)

# Percentage of counts from genes matching "^mt-" (mitochondrial genes)
sobj_ext[["percent_mt"]] <- PercentageFeatureSet(sobj_ext, pattern = "^mt-")

# Keep cells with more than 200 and fewer than 5000 detected features and
# less than 15% mitochondrial counts
sobj_ext <- subset(
  sobj_ext,
  subset = nFeature_RNA > 200 & nFeature_RNA < 5000 & percent_mt < 15
)

# Label every external cell with the same sample name
sobj_ext@meta.data$orig.ident <- rep("ext_data", nrow(sobj_ext@meta.data))

# Log-normalize counts, scale the RNA assay, then find variable features
# (same ScaleData -> FindVariableFeatures order as merge_sobj)
sobj_ext <- sobj_ext %>%
  NormalizeData(normalization.method = "LogNormalize") %>%
  ScaleData(assay = "RNA") %>%
  FindVariableFeatures(assay = "RNA")

# Cluster the external cells on the RNA assay, then copy the "umap"
# embedding into the metadata as UMAP_1 / UMAP_2
sobj_ext <- cluster_RNA(
  sobj_ext,
  assay      = "RNA",
  resolution = 0.6,
  dims       = 1:40
)

sobj_ext <- AddMetaData(
  sobj_ext,
  metadata = Embeddings(sobj_ext, reduction = "umap"),
  col.name = c("UMAP_1", "UMAP_2")
)

# Clustify using our data as reference: annotate each external
# seurat_clusters group against subtype_ref
sobj_ext <- clustify(
  sobj_ext,
  ref_mat     = subtype_ref,
  cluster_col = "seurat_clusters"
)

# Create list of all sobjs: our samples followed by the external object,
# which is stored under the name "ext_data"
all_sobjs <- append(sobjs, list(ext_data = sobj_ext))

# Merge objects and run PCA, UMAP, clustering
# Our samples are ordered first (as in names(sobjs)); ext_data, which is not
# in sample_order, follows them. The "umap" embedding is stored as
# merge_UMAP_1 / merge_UMAP_2.
sobj_merge <- merge_sobj(all_sobjs, sample_order = names(sobjs))

sobj_merge <- cluster_RNA(
  sobj_merge,
  assay      = "RNA",
  resolution = 1,
  dims       = 1:40
)

sobj_merge <- AddMetaData(
  sobj_merge,
  metadata = Embeddings(sobj_merge, reduction = "umap"),
  col.name = c("merge_UMAP_1", "merge_UMAP_2")
)

# Integrate samples
# Split the merged object by orig.ident and find variable features per
# sample. Anchor and integrate on dims 1:40, then scale the "integrated"
# assay.
sobj_int <- sobj_merge %>%
  SplitObject(split.by = "orig.ident") %>%
  lapply(FindVariableFeatures) %>%
  FindIntegrationAnchors(dims = 1:40) %>%
  IntegrateData(dims = 1:40) %>%
  ScaleData(assay = "integrated")

# Cluster using every feature of the integrated object; the prefix "int_"
# names the resulting reductions/columns
sobj_int <- cluster_RNA(
  sobj_int,
  assay      = "integrated",
  resolution = 1,
  features   = rownames(sobj_int),
  prefix     = "int_"
)

sobj_int <- AddMetaData(
  sobj_int,
  metadata = Embeddings(sobj_int, reduction = "int_umap"),
  col.name = c("UMAP_1", "UMAP_2")
)

# Combined "<orig.ident>-<type>" label, e.g. "d14-d14_Floor_LECs"
sobj_int@meta.data[["orig_type"]] <- str_c(
  sobj_int@meta.data[["orig.ident"]],
  "-",
  sobj_int@meta.data[["type"]]
)

d2_LEC, d14_LEC, GSM4306928_LEC_10…

Marker genes were identified by comparing the samples. Gene expression signal is shown for the top marker genes (top, middle). GO terms were identified for the marker genes; the top terms are labeled, and the size of each point indicates the number of overlapping genes (bottom).

d2

32 marker genes and 78 GO terms were identified.




d14

157 marker genes and 243 GO terms were identified.




ext_data

24 marker genes and 78 GO terms were identified.




Cells were divided by subtype, and marker genes were identified by comparing the samples within each subtype. Gene expression signal is shown for the top marker genes (top, middle). GO terms were identified for the marker genes; the top terms are labeled, and the size of each point indicates the number of overlapping genes (bottom).

d14-d14_Ceiling_LECs

92 marker genes were identified. 31 GO terms were identified.




ext_data-d14_Ceiling_LECs

20 marker genes were identified. 5 GO terms were identified.




d14-d14_Floor_LECs

177 marker genes were identified. 350 GO terms were identified.




ext_data-d14_Floor_LECs

49 marker genes were identified. 55 GO terms were identified.